package fun.codedesign.jdk.math.algorithm.nlp;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Keyword extraction based on the TextRank algorithm. <br>
 * Every word "votes" for the words that co-occur with it inside a sliding window; the more votes
 * (and the better-ranked the voters) a word receives, the higher its weight. Function words such
 * as structural particles are not filtered here and may need to be removed by the caller. <br>
 *
 * <p>Instances hold no mutable state, so every call to {@link #rank(String)} is independent of
 * any previous call.
 *
 * @author zengjian
 * @create 2018-06-05 14:24
 * @since 1.0.0
 */
public class TextRank {

    /**
     * Window radius: a word is linked to at most this many words on each side of it.
     */
    private static final int WINDOWS_LENGTH = 3;

    /**
     * Damping factor d; every word therefore starts with (and never drops below) a score of 1 - d.
     */
    private static final double DAMPING_FACTOR = 0.85;

    /**
     * Maximum number of score-update iterations.
     */
    private static final int MAX_ITER = 10;

    /**
     * Convergence threshold: iteration stops once no score changes by more than this amount.
     */
    private static final double MIN_DIFF = 0.02;

    /**
     * Computes a TextRank score for every distinct word of a pre-segmented text.
     * <pre>
     * 1. Build an undirected co-occurrence graph: two different words are linked when they
     *    appear within WINDOWS_LENGTH positions of each other.
     * 2. Start every word at 1 - d and repeatedly apply
     *        WS(Vi) = (1 - d) + d * SUM over neighbours Vj of ( WS(Vj) / degree(Vj) )
     *    always reading the scores of the previous iteration.
     * 3. Stop after MAX_ITER iterations, or earlier once the largest change is at most MIN_DIFF.
     * </pre>
     *
     * @param context the segmented text, words separated by whitespace; may be {@code null}
     * @return a new map from each distinct word to its score; empty when {@code context} is
     *         {@code null} or contains no words
     */
    public Map<String, Double> rank(String context) {
        Map<String, Set<String>> graph = buildGraph(tokenize(context));

        Map<String, Double> scores = new HashMap<>();
        for (String word : graph.keySet()) {
            scores.put(word, 1 - DAMPING_FACTOR);
        }

        for (int iter = 0; iter < MAX_ITER; iter++) {
            // Write into a fresh map so that every update of this pass reads only the scores of
            // the previous pass, and so that the convergence check compares two distinct states.
            Map<String, Double> next = new HashMap<>();
            double maxDiff = 0;
            for (Map.Entry<String, Set<String>> entry : graph.entrySet()) {
                String word = entry.getKey();
                double score = 1 - DAMPING_FACTOR;
                for (String neighbour : entry.getValue()) {
                    // The graph is symmetric, so a neighbour always has at least one link
                    // (back to this word) and the division is safe.
                    int degree = graph.get(neighbour).size();
                    score += DAMPING_FACTOR / degree * scores.get(neighbour);
                }
                next.put(word, score);
                maxDiff = Math.max(maxDiff, Math.abs(score - scores.get(word)));
            }
            scores = next;
            if (maxDiff <= MIN_DIFF) {
                break;
            }
        }
        return scores;
    }

    /**
     * Splits the text on runs of whitespace so that repeated separators never yield empty words.
     *
     * @param context the raw text, may be {@code null}
     * @return the words in order of appearance; an empty array for {@code null} or blank input
     */
    private static String[] tokenize(String context) {
        if (context == null) {
            return new String[0];
        }
        String trimmed = context.trim();
        if (trimmed.isEmpty()) {
            return new String[0];
        }
        return trimmed.split("\\s+");
    }

    /**
     * Builds the symmetric co-occurrence graph of the given word sequence.
     * Every word becomes a key (even when it ends up with no neighbours) and a word is never
     * recorded as its own neighbour.
     *
     * @param words the words in order of appearance
     * @return a map from each distinct word to the set of distinct other words seen within
     *         WINDOWS_LENGTH positions of any of its occurrences
     */
    private static Map<String, Set<String>> buildGraph(String[] words) {
        Map<String, Set<String>> graph = new HashMap<>();
        for (int i = 0; i < words.length; i++) {
            String word = words[i];
            Set<String> neighbours = graph.get(word);
            if (neighbours == null) {
                neighbours = new HashSet<>();
                graph.put(word, neighbours);
            }
            // Same reach on both sides, which is what keeps the graph symmetric.
            int lower = Math.max(0, i - WINDOWS_LENGTH);
            int upper = Math.min(words.length - 1, i + WINDOWS_LENGTH);
            for (int j = lower; j <= upper; j++) {
                if (!words[j].equals(word)) {
                    neighbours.add(words[j]);
                }
            }
        }
        return graph;
    }
}
